load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Please reach out to tf-bridge-team@ before using the TF2XLA bridge.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    # By default, targets here are visible only to this package and its
    # subpackages. Targets that must be reachable from elsewhere set their own
    # `visibility`, which replaces (not extends) this default.
    default_visibility = [
        ":__subpackages__",
    ],
)

# C++ library built from legalize_tf.cc and exposing legalize_tf.h.
# Access is limited to the explicitly allow-listed packages below.
cc_library(
    name = "legalize_tf",
    srcs = ["legalize_tf.cc"],
    hdrs = ["legalize_tf.h"],
    # Explicit allow-list; this replaces the package-level default visibility.
    visibility = [
        "//learning/brain/google/xla:__pkg__",
        "//learning/brain/mlir/bridge:__pkg__",
        "//tensorflow/compiler/mlir/quantization/stablehlo:__pkg__",
        "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__",
    ],
    deps = [
        # Local target; presumably the C++ output of the `device_type_proto`
        # rule in this file -- confirm against tf_proto_library.
        ":device_type_proto_cc",
        "//tensorflow/compiler/jit:flags_headers",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
        "//tensorflow/compiler/mlir/tensorflow/transforms:set_tpu_infeed_layout",
        "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
        "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util",
        "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_tf_graph",
        "//tensorflow/compiler/mlir/tf2xla/internal:compilation_timer",
        "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_mlir",
        "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_to_hlo",
        "//tensorflow/compiler/mlir/tf2xla/internal:reproducer_proto_cc",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/tpu:tpu_compile",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
        "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_util_hdrs",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:variant",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:error_logging",
        "@local_tsl//tsl/platform:protobuf",
        "@local_tsl//tsl/platform:status",
        "@local_tsl//tsl/platform:statusor",
        "@local_xla//xla:xla_proto_cc",
        "@local_xla//xla/client:compile_only_client",
        "@local_xla//xla/hlo/ir:hlo",
        "@local_xla//xla/mlir_hlo:hlo_dialect_registration",
        "@local_xla//xla/pjrt:compile_options_proto_cc",
        "@stablehlo//:register",
    ],
)

# Test target for :legalize_tf, built from legalize_tf_test.cc.
# Note: this test depends on //tensorflow/core:test_main plus the plain
# `gtest` target, whereas the other tests in this file use `gtest_main`.
tf_cc_test(
    name = "legalize_tf_test",
    srcs = ["legalize_tf_test.cc"],
    deps = [
        ":device_type_proto_cc",
        # Library under test.
        ":legalize_tf",
        "//tensorflow/compiler/jit",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tf2xla/internal:test_matchers",
        "//tensorflow/compiler/mlir/tf2xla/internal/utils:test_metadata_config",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/monitoring:cell_reader",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/lib/monitoring:test_utils",
        "@local_tsl//tsl/platform:statusor",
        "@local_xla//xla/client:client_library",
        "@local_xla//xla/stream_executor:platform_manager",
    ],
)

# Proto library for device_type.proto.
# Other targets in this file depend on `:device_type_proto_cc`; presumably that
# C++ target is generated by this macro -- confirm in
# //tensorflow/core/platform:build_config.bzl.
tf_proto_library(
    name = "device_type_proto",
    srcs = ["device_type.proto"],
    cc_api_version = 2,
    # Replaces the package default visibility with this single external package.
    visibility = [
        "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__",
    ],
)

# C++ library built from cluster_tf.cc and exposing cluster_tf.h.
# Access is limited to the explicitly allow-listed packages below.
cc_library(
    name = "cluster_tf",
    srcs = ["cluster_tf.cc"],
    hdrs = ["cluster_tf.h"],
    # Explicit allow-list; this replaces the package-level default visibility.
    visibility = [
        "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__",
        "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__",
        "//tensorflow/compiler/mlir/tfrt:__pkg__",
        "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__",
        "//tensorflow/compiler/tf2xla:__pkg__",
    ],
    deps = [
        ":device_type_proto_cc",
        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
        "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes",
        "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/platform:error_payloads",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:stacktrace",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:error_logging",
        "@local_tsl//tsl/platform:errors",
    ],
)

# Test target for :cluster_tf, built from cluster_tf_test.cc.
tf_cc_test(
    name = "cluster_tf_test",
    srcs = ["cluster_tf_test.cc"],
    # MLIR fixture files made available to the test at runtime.
    # Presumably located via //tensorflow/core/platform:resource_loader (listed
    # in deps) -- confirm in the test source.
    data = [
        "testdata/empty_func.mlir",
        "testdata/invalid_executor.mlir",
        "testdata/outside_compilation.mlir",
    ],
    deps = [
        # Library under test.
        ":cluster_tf",
        "//tensorflow/compiler/mlir:register_common_dialects",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
        "//tensorflow/compiler/mlir/tf2xla/api/v2/testing:utils",
        "//tensorflow/core/lib/monitoring:cell_reader",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@local_tsl//tsl/platform:status",
        "@local_xla//xla/tsl/lib/core:status_test_util",
    ],
)

# C++ library built from tf_dialect_to_executor.cc and exposing
# tf_dialect_to_executor.h. Access is limited to the allow-listed packages below.
cc_library(
    name = "tf_dialect_to_executor",
    srcs = ["tf_dialect_to_executor.cc"],
    hdrs = ["tf_dialect_to_executor.h"],
    # Explicit allow-list; this replaces the package-level default visibility.
    visibility = [
        "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__",
        "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__",
        "//tensorflow/compiler/mlir/tfrt:__pkg__",
        "//tensorflow/compiler/tf2xla:__pkg__",
    ],
    deps = [
        "//tensorflow/compiler/jit:flags_headers",
        "//tensorflow/compiler/mlir/tensorflow:bridge_logger",
        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
        "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
        "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks",
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:error_payloads",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
        "@local_tsl//tsl/lib/monitoring:counter",
        "@local_tsl//tsl/platform:error_logging",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:status",
    ],
)

# Test target for :tf_dialect_to_executor, built from
# tf_dialect_to_executor_test.cc.
tf_cc_test(
    name = "tf_dialect_to_executor_test",
    srcs = ["tf_dialect_to_executor_test.cc"],
    # MLIR fixture files made available to the test at runtime.
    # empty_func.mlir and invalid_executor.mlir are also listed by
    # cluster_tf_test in this file.
    data = [
        "testdata/empty_func.mlir",
        "testdata/func_with_dead_ops.mlir",
        "testdata/invalid_executor.mlir",
    ],
    deps = [
        # Library under test.
        ":tf_dialect_to_executor",
        "//tensorflow/compiler/mlir:register_common_dialects",
        "//tensorflow/compiler/mlir/tf2xla/api/v2/testing:utils",
        "//tensorflow/core/lib/monitoring:cell_reader",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@local_xla//xla/tsl/lib/core:status_test_util",
    ],
)

# C++ library built from tf_executor_to_graph.cc and exposing
# tf_executor_to_graph.h.
cc_library(
    name = "tf_executor_to_graph",
    srcs = ["tf_executor_to_graph.cc"],
    hdrs = ["tf_executor_to_graph.h"],
    # Publicly visible: any package may depend on this target, overriding the
    # restrictive package-level default.
    visibility = ["//visibility:public"],
    deps = [
        # TensorFlow MLIR dialect, conversion and export helpers.
        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:export_utils",
        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
        "//tensorflow/compiler/mlir/tensorflow:verify_suitable_for_graph_export",
        "//tensorflow/compiler/mlir/tensorflow/translate:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/utils:name_utils",
        # TensorFlow core targets.
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/graph/regularization:util",
        # External repositories: Abseil, LLVM/MLIR, XLA.
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@local_xla//xla:status_macros",
    ],
)
